import math
import torch
import torch.nn as nn

class LSTMModel(nn.Module):
    """Multi-layer ``nn.LSTM`` followed by dropout and a linear head.

    The head maps every timestep's hidden state to a single scalar, so the
    model produces one value per timestep.

    Args:
        input_size: Number of features per timestep.
        hidden_size: Width of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability. Applied to the LSTM output before the
            linear head and, when ``num_layers > 1``, between LSTM layers.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout=0.1):
        super().__init__()
        # nn.LSTM only applies dropout *between* stacked layers. With a single
        # layer a non-zero value has no effect and merely triggers a
        # UserWarning, so pass 0.0 in that case.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            batch_first=True, dropout=inter_layer_dropout)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Run the model on a batch-first sequence.

        Args:
            x: Tensor of shape (batch, timesteps, input_size).

        Returns:
            Tensor of shape (batch, timesteps).
        """
        out, _ = self.lstm(x)           # (batch, timesteps, hidden_size)
        out = self.dropout(out)
        out = self.fc(out)              # (batch, timesteps, 1)
        # The head has a single output feature, so the last axis is always 1.
        return out.squeeze(-1)          # → (batch, timesteps)

class StackedLSTMWithLayerNorm(nn.Module):
    """Stack of ``nn.LSTMCell`` layers with LayerNorm on each hidden state.

    The sequence is unrolled one timestep at a time. At every step each
    layer's new hidden state is layer-normalized; the normalized state is both
    fed to the next layer and carried over as that layer's recurrent hidden
    state. The top layer's state goes through dropout and a linear head that
    emits one scalar per timestep.

    Args:
        input_size: Number of features per timestep.
        hidden_size: Width of every layer's hidden and cell state.
        num_layers: Number of stacked LSTMCell layers.
        dropout: Dropout probability applied to the top layer's hidden state
            before the linear head.
    """

    def __init__(self, input_size, hidden_size, num_layers=2, dropout=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # One LSTMCell + LayerNorm pair per layer. Only the first layer sees
        # the raw input features; deeper layers consume the hidden state of
        # the layer below.
        self.cells = nn.ModuleList()
        self.norms = nn.ModuleList()
        for layer in range(num_layers):
            in_size = input_size if layer == 0 else hidden_size
            self.cells.append(nn.LSTMCell(in_size, hidden_size))
            self.norms.append(nn.LayerNorm(hidden_size))

        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Run the stack on a batch-first sequence.

        Args:
            x: Tensor of shape (batch, timesteps, input_size). At least one
                timestep is required, since the per-step outputs are stacked.

        Returns:
            Tensor of shape (batch, timesteps).
        """
        batch_size, seq_len, _ = x.size()

        # Zero initial hidden/cell states for all layers. new_zeros inherits
        # both dtype and device from x, so the model also works after
        # .double() / .half() instead of failing on a float32 state.
        h = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.num_layers)]
        c = [x.new_zeros(batch_size, self.hidden_size) for _ in range(self.num_layers)]

        outputs = []
        for t in range(seq_len):
            layer_input = x[:, t, :]

            for layer, (cell, norm) in enumerate(zip(self.cells, self.norms)):
                h_raw, c[layer] = cell(layer_input, (h[layer], c[layer]))
                # The normalized state is what gets carried to the next
                # timestep and passed up to the next layer.
                h[layer] = norm(h_raw)
                layer_input = h[layer]

            out = self.dropout(h[-1])
            out = self.fc(out)          # (batch, 1)
            outputs.append(out)

        outputs = torch.stack(outputs, dim=1)   # (batch, timesteps, 1)
        return outputs.squeeze(-1)


# Chrono initialization helpers for LSTM

def _chrono_sample_bf(hidden_size: int, Tmax: int, Tmin: float, device):
    """
    Sample chrono forget-gate bias b_f = log(u), where u ~ Uniform[Tmin, Tmax).

    Args:
        hidden_size: Number of bias values to draw (one per hidden unit).
        Tmax: Upper end of the sampling range.
        Tmin: Lower end of the sampling range; must be strictly positive
            because the logarithm of the sample is taken.
        device: Device on which the samples are created.

    Returns:
        1-D tensor of length ``hidden_size`` holding log(u).

    Raises:
        ValueError: If ``Tmin <= 0`` or ``Tmax < Tmin``. Either would let
            non-positive values reach the log and yield -inf/NaN biases.
    """
    if Tmin <= 0:
        raise ValueError(f"Tmin must be > 0 for chrono initialization, got {Tmin}")
    if Tmax < Tmin:
        raise ValueError(f"Tmax must be >= Tmin for chrono initialization, got Tmax={Tmax}, Tmin={Tmin}")

    # torch.rand draws from [0, 1); rescale it to [Tmin, Tmax).
    u = (Tmax - Tmin) * torch.rand(hidden_size, device=device) + Tmin
    return torch.log(u)

def apply_chrono_init_lstm(lstm_module: nn.LSTM, Tmax: int, Tmin: float = 1.0, debug: bool = False) -> nn.LSTM:
    """
    Chrono-initialize the input/forget gate biases of an nn.LSTM, in place.

    PyTorch packs every LSTM bias vector as four equal chunks in gate order
    (i, f, g, o). For each layer (and each direction when the module is
    bidirectional) a fresh b_f is drawn with _chrono_sample_bf and written as:
      bias_ih[i] = -b_f      bias_ih[f] = b_f
      bias_hh[i] = 0         bias_hh[f] = 0
    The two bias vectors are summed inside the LSTM, so zeroing the bias_hh
    chunks makes the effective gate bias exactly the chrono value. The g and
    o chunks are left untouched. A layer/direction that lacks either bias
    attribute is skipped.

    With debug=True the first five sampled values of layer 0 (forward
    direction) are printed.

    Raises ValueError when Tmax is None. Returns the same module object.
    """
    if Tmax is None:
        raise ValueError("Tmax must be provided for chrono initialization")

    with torch.no_grad():
        layer_count = getattr(lstm_module, "num_layers", 1)
        if getattr(lstm_module, "bidirectional", False):
            directions = ("", "_reverse")
        else:
            directions = ("",)

        for layer_idx in range(layer_count):
            for direction in directions:
                bias_ih = getattr(lstm_module, f"bias_ih_l{layer_idx}{direction}", None)
                bias_hh = getattr(lstm_module, f"bias_hh_l{layer_idx}{direction}", None)
                if bias_ih is not None and bias_hh is not None:
                    # Each of the four gates owns a quarter of the vector.
                    gate_width = bias_ih.shape[0] // 4

                    forget_bias = _chrono_sample_bf(
                        gate_width, Tmax=Tmax, Tmin=Tmin, device=bias_ih.device
                    )
                    input_bias = -forget_bias

                    # Chunk 0 is the input gate, chunk 1 the forget gate.
                    bias_ih[:gate_width].copy_(input_bias)
                    bias_ih[gate_width:2 * gate_width].copy_(forget_bias)

                    # Input and forget chunks are adjacent, so clear both at
                    # once to keep bias_ih the sole contributor.
                    bias_hh[:2 * gate_width].zero_()

                    if debug and layer_idx == 0 and direction == "":
                        print("[Chrono-LSTM] b_f (first 5):", forget_bias[:5].cpu().numpy())
                        print("[Chrono-LSTM] b_i (first 5):", input_bias[:5].cpu().numpy())

    return lstm_module

def apply_chrono_init_lstmcell(cell: nn.LSTMCell, Tmax: int, Tmin: float = 1.0, debug: bool = False) -> nn.LSTMCell:
    """
    Apply chrono initialization to a single nn.LSTMCell (gate order i,f,g,o).

    A forget-gate bias b_f is drawn with _chrono_sample_bf and written, in
    place, as:
      bias_ih[i] = -b_f      bias_ih[f] = b_f
      bias_hh[i] = 0         bias_hh[f] = 0
    The g and o chunks are left unchanged.

    A cell built with ``bias=False`` has no bias parameters; it is returned
    untouched, mirroring how apply_chrono_init_lstm skips layers without
    biases instead of failing.

    Args:
        cell: The LSTMCell to initialize.
        Tmax: Upper end of the range b_f is sampled from (required).
        Tmin: Lower end of that range.
        debug: When True, print the first five sampled values.

    Returns:
        The same ``cell`` object.

    Raises:
        ValueError: If ``Tmax`` is None.
    """
    if Tmax is None:
        raise ValueError("Tmax must be provided for chrono initialization")

    # bias=False registers both biases as None: nothing to initialize.
    if getattr(cell, "bias_ih", None) is None or getattr(cell, "bias_hh", None) is None:
        return cell

    with torch.no_grad():
        # The bias vector holds four equally sized gate chunks.
        H = cell.bias_ih.shape[0] // 4
        device = cell.bias_ih.device

        bf = _chrono_sample_bf(H, Tmax=Tmax, Tmin=Tmin, device=device)

        i0, i1 = 0, H          # input gate chunk
        f0, f1 = H, 2 * H      # forget gate chunk

        cell.bias_ih[i0:i1].copy_(-bf)
        cell.bias_ih[f0:f1].copy_(bf)

        # bias_ih and bias_hh are summed inside the cell; zero the matching
        # bias_hh chunks so the effective bias is exactly the chrono value.
        cell.bias_hh[i0:i1].zero_()
        cell.bias_hh[f0:f1].zero_()

        if debug:
            print("[Chrono-LSTMCell] b_f (first 5):", bf[:5].cpu().numpy())
            print("[Chrono-LSTMCell] b_i (first 5):", (-bf)[:5].cpu().numpy())

    return cell

def apply_chrono_init_layernorm_lstm(model: nn.Module, Tmax: int, Tmin: float = 1.0, debug: bool = False) -> nn.Module:
    """
    Chrono-initialize every LSTMCell held in ``model.cells``.

    Each cell is passed to apply_chrono_init_lstmcell with the given
    Tmax/Tmin. Debug output is requested for the first cell only.

    Raises AttributeError when ``model`` has no ``cells`` attribute.
    Returns the same ``model`` object.
    """
    if hasattr(model, "cells"):
        is_first = True
        for lstm_cell in model.cells:
            # Only the first cell may print, to keep debug output short.
            apply_chrono_init_lstmcell(lstm_cell, Tmax=Tmax, Tmin=Tmin, debug=(debug and is_first))
            is_first = False
        return model

    raise AttributeError("Expected model to have 'cells' (ModuleList of LSTMCell).")



